{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "import os\n",
    "import shutil\n",
    "import cv2\n",
    "from PIL import Image, ImageDraw\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from tqdm import tqdm\n",
    "\n",
    "import sys\n",
    "sys.path.append('/home/dan/work/cocoapi/PythonAPI/')\n",
    "\n",
    "from pycocotools.coco import COCO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = '/home/dan/datasets/COCO'\n",
    "\n",
    "DATA_TYPE = 'train2017'\n",
    "RESULT_PATH = '/home/dan/datasets/COCO/train_annotations/'\n",
    "\n",
    "# DATA_TYPE = 'val2017'\n",
    "# RESULT_PATH = '/home/dan/datasets/COCO/val_annotations/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coco = COCO('{}/annotations/instances_{}.json'.format(DATA_DIR, DATA_TYPE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "categories = coco.loadCats(coco.getCatIds())\n",
    "coco_id_to_name = {c['id']: c['name'] for c in categories}\n",
    "names = [cat['name'] for cat in categories]\n",
    "print('COCO categories: \\n{}'.format(' '.join(names)))\n",
    "print('\\nnumber of labels:', len(names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create label encoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of row - integer encoding\n",
    "if not os.path.exists('coco_labels.txt'):\n",
    "    with open('coco_labels.txt', 'w') as f:\n",
    "        for n in names:\n",
    "            f.write(n + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show an image with full annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "catIds = coco.getCatIds(catNms=['person', 'zebra'])\n",
    "imgIds = coco.getImgIds(catIds=catIds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i = np.random.randint(0, len(imgIds))\n",
    "metadata = coco.loadImgs(imgIds[i])[0]\n",
    "image = cv2.imread('%s/images/%s/%s' % (DATA_DIR, DATA_TYPE, metadata['file_name']))\n",
    "image = image[:, :, [2, 1, 0]]\n",
    "annIds = coco.getAnnIds(imgIds=metadata['id'], catIds=catIds, iscrowd=False)\n",
    "annotations = coco.loadAnns(annIds)\n",
    "\n",
    "plt.imshow(image)\n",
    "plt.axis('off')\n",
    "coco.showAnns(annotations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show bounding box annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_boxes_on_image(image, annotations):\n",
    "\n",
    "    image = Image.fromarray(image)\n",
    "    draw = ImageDraw.Draw(image, 'RGBA')\n",
    "    width, height = image.size\n",
    "\n",
    "    for a in annotations:\n",
    "        xmin, ymin, w, h = a['bbox']\n",
    "        xmax, ymax = xmin + w, ymin + h\n",
    "\n",
    "        fill = (255, 255, 255, 45)\n",
    "        outline = 'red'\n",
    "        draw.rectangle(\n",
    "            [(xmin, ymin), (xmax, ymax)],\n",
    "            fill=fill, outline=outline\n",
    "        )\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_boxes_on_image(image, annotations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a particular class:\n",
    "# catIds = coco.getCatIds(catNms=['person'])\n",
    "# imgIds = coco.getImgIds(catIds=catIds)\n",
    "\n",
    "# get all:\n",
    "catIds = coco.getCatIds()\n",
    "imgIds = coco.getImgIds()\n",
    "\n",
    "print('number of images:', len(imgIds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_annotation(i):\n",
    "    metadata = coco.loadImgs(i)[0]\n",
    "    annIds = coco.getAnnIds(imgIds=metadata['id'], catIds=catIds, iscrowd=False)\n",
    "    height, width = metadata['height'], metadata['width']\n",
    "    annotation = {\n",
    "      \"filename\": metadata['file_name'],\n",
    "      \"size\": {\"depth\": 3, \"width\": width, \"height\": height}\n",
    "    }\n",
    "    objects = []\n",
    "    for a in coco.loadAnns(annIds):\n",
    "        label = coco_id_to_name[a['category_id']]\n",
    "        xmin, ymin, w, h = a['bbox']\n",
    "        xmax, ymax = xmin + w, ymin + h\n",
    "        \n",
    "        ymin = min(ymin, ymax)\n",
    "        xmin = min(xmin, xmax)\n",
    "        ymax = max(ymin, ymax)\n",
    "        xmax = max(xmin, xmax)\n",
    "        \n",
    "        ymin = min(max(0, ymin), height)\n",
    "        xmin = min(max(0, xmin), width)\n",
    "        ymax = max(min(height, ymax), 0)\n",
    "        xmax = max(min(width, xmax), 0)\n",
    "        \n",
    "        if (ymax - ymin) < 1 or (xmax - xmin) < 1:\n",
    "            continue\n",
    "\n",
    "        objects.append({\"bndbox\": {\"ymin\": ymin, \"ymax\": ymax, \"xmax\": xmax, \"xmin\": xmin}, \"name\": label})\n",
    "\n",
    "    annotation[\"object\"] = objects\n",
    "    return annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shutil.rmtree(RESULT_PATH, ignore_errors=True)\n",
    "os.mkdir(RESULT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for i in tqdm(imgIds):\n",
    "    d = get_annotation(i)\n",
    "    filename = d['filename']\n",
    "    assert filename.endswith('.jpg')\n",
    "    name = filename[:-4]\n",
    "    json.dump(d, open(os.path.join(RESULT_PATH, name + '.json'), 'w')) "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
